from io import BytesIO

import numpy as np
from docx import Document, ImagePart
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.table import Table, _Cell
from docx.text.paragraph import Paragraph
from PIL import Image
from rapidocr_onnxruntime import RapidOCR
from tqdm import tqdm


def doc2text(filepath) -> str:
    """Extract the text of a Word document, running OCR on embedded pictures.

    Top-level blocks are visited in document order.  For a paragraph, its
    stripped text is emitted, followed by the OCR text of every picture the
    paragraph embeds.  For a table, the stripped text of every paragraph of
    every cell is emitted row by row.  Every emitted piece of text ends with
    a newline, so OCR output never runs into the text of the next block.

    Args:
        filepath: Whatever ``docx.Document`` accepts to open the document
            (it is passed through unchanged).

    Returns:
        The collected text as a single newline-separated string.

    Raises:
        ValueError: If ``iter_block_items`` is handed a parent that is
            neither a document nor a table cell.
    """
    ocr = RapidOCR()
    doc = Document(filepath)
    # Collect pieces and join once at the end instead of repeated ``+=``.
    chunks = []

    def iter_block_items(parent):
        """Yield direct child paragraphs and tables of *parent* in order."""
        # The document *class* is needed for the isinstance check; the
        # module-level ``Document`` name is what is called to open the file.
        from docx.document import Document as DocumentCls

        if isinstance(parent, DocumentCls):
            parent_elm = parent.element.body
        elif isinstance(parent, _Cell):
            parent_elm = parent._tc
        else:
            raise ValueError("RapidOCRDocLoader parse fail")

        for child in parent_elm.iterchildren():
            if isinstance(child, CT_P):
                yield Paragraph(child, parent)
            elif isinstance(child, CT_Tbl):
                yield Table(child, parent)

    total_blocks = len(doc.paragraphs) + len(doc.tables)
    # The context manager closes the progress bar even if parsing or OCR
    # raises part-way through the document.
    with tqdm(
        total=total_blocks,
        desc="RapidOCRDocLoader block index: 0",
    ) as b_unit:
        for i, block in enumerate(iter_block_items(doc)):
            b_unit.set_description("RapidOCRDocLoader  block index: {}".format(i))
            b_unit.refresh()
            if isinstance(block, Paragraph):
                chunks.append(block.text.strip() + "\n")
                # Every picture element embedded in this paragraph.
                for pic_elm in block._element.xpath(".//pic:pic"):
                    # Relationship ids of the image data the picture references.
                    for img_id in pic_elm.xpath(".//a:blip/@r:embed"):
                        # Resolve the relationship id to its package part.
                        part = doc.part.related_parts[img_id]
                        if not isinstance(part, ImagePart):
                            continue
                        # Decode inside ``with`` so the PIL image is closed
                        # as soon as its pixels have been copied out.
                        with Image.open(BytesIO(part._blob)) as picture:
                            pixels = np.array(picture)
                        result, _ = ocr(pixels)
                        if result:
                            # Index 1 of each entry is taken as the
                            # recognised text (per the original code).
                            ocr_lines = [line[1] for line in result]
                            # Trailing newline keeps the last OCR line from
                            # fusing with the next block's text.
                            chunks.append("\n".join(ocr_lines) + "\n")
            elif isinstance(block, Table):
                for row in block.rows:
                    for cell in row.cells:
                        for paragraph in cell.paragraphs:
                            chunks.append(paragraph.text.strip() + "\n")
            b_unit.update(1)
    return "".join(chunks)
